{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8e808dee-afa0-4200-bb5b-6f384f7a508a",
   "metadata": {},
   "source": [
    "# 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c1da73-107f-45ba-958f-3b30dc0b5f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sentence_transformers.evaluation import (\n",
    "    InformationRetrievalEvaluator,\n",
    "    SequentialEvaluator,\n",
    ")\n",
    "import pandas as pd\n",
    "from sentence_transformers.util import cos_sim\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "from sentence_transformers import SentenceTransformerTrainer\n",
    "\n",
    "model_id = \"/nfs_data_new/wp/weights/bge-small-zh\" \n",
    " \n",
    "model = SentenceTransformer(\n",
    "    model_id, device=\"cuda\" if torch.cuda.is_available() else \"cpu\",model_kwargs={'ignore_mismatched_sizes':True}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81df7e4a-97b6-49c7-866b-35c65c99113f",
   "metadata": {},
   "source": [
    "# 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a6c0e5-79ee-43e1-b91d-88f30a3327fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset  = load_dataset(\"csv\", data_files=\"embedding_data.csv\")['train']\n",
    "dataset = dataset.rename_column(\"context\", \"positive\")\n",
    "dataset = dataset.rename_column(\"query\", \"anchor\")\n",
    "\n",
    "dataset = dataset.train_test_split(test_size=0.2, shuffle=True)\n",
    "\n",
    "train_dataset = dataset['train']\n",
    "test_dataset = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7bc66613-3009-440d-97d3-0670c1fe193a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcd9763c-ae30-4ece-b79f-7c9ffe95e08f",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = list(set([_['positive'] for _ in test_dataset]))\n",
    "queries = list(set([_['anchor'] for _ in test_dataset]))\n",
    "\n",
    "id2corpus = dict(zip(range(len(corpus)),corpus))\n",
    "id2corpus = {str(i):j for i,j in id2corpus.items()}\n",
    "\n",
    "id2queries = dict(zip(range(len(queries)),queries))\n",
    "id2queries = {str(i):j for i,j in id2queries.items()}\n",
    "relevant_docs = {_['anchor']:_['positive'] for _ in test_dataset}\n",
    "relevant_docs = {str(queries.index(i)):str(corpus.index(j)) for i,j in relevant_docs.items()}\n",
    "\n",
    "  \n",
    "ir_evaluator = InformationRetrievalEvaluator(\n",
    "    queries=id2queries,\n",
    "    corpus=id2corpus,\n",
    "    relevant_docs=relevant_docs,\n",
    "    name=f\"bge_m3\",\n",
    "    score_functions={\"cosine\": cos_sim},\n",
    "batch_size=8\n",
    ")\n",
    " \n",
    "evaluator = SequentialEvaluator([ir_evaluator])\n",
    "results = evaluator(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953fd9ea-6885-4ac7-9bb9-b16282983e19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers.losses import MultipleNegativesRankingLoss\n",
    "train_loss = MultipleNegativesRankingLoss(model)\n",
    "\n",
    "\n",
    "from sentence_transformers import SentenceTransformerTrainingArguments\n",
    "from sentence_transformers.training_args import BatchSamplers\n",
    " \n",
    "\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"bge-m3-ft\", \n",
    "    num_train_epochs=4,\n",
    "    per_device_train_batch_size=16,\n",
    "    gradient_accumulation_steps=2,\n",
    "    per_device_eval_batch_size=16,\n",
    "    warmup_ratio=0.1,\n",
    "    learning_rate=2e-5,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"adamw_torch_fused\",\n",
    "    tf32=True,\n",
    "    bf16=True,\n",
    "    batch_sampler=BatchSamplers.NO_DUPLICATES,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_steps=10,\n",
    "    save_total_limit=3,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"bge_m3_cosine_ndcg@10\"\n",
    ")\n",
    "\n",
    "\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=model,\n",
    "    args=args,  \n",
    "    train_dataset=train_dataset.select_columns(\n",
    "        [\"positive\", \"anchor\"]\n",
    "    ), \n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d067275-04ee-4725-90bf-317ad5a941e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a64965-f9c6-4094-a8c3-af53ca3d4443",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0rc2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
